{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from dataset.Dataset import TGS_Dataset, do_horizontal_flip\n",
    "from model.Models import UnetSeNext50\n",
    "from losses.Evaluation import do_length_decode, do_length_encode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_net_and_predict(net, test_path, load_paths, batch_size=32, tta_transform=None, threshold=0.5, min_size=0):\n",
    "    sum_ = np.zeros((18000, 1, 101, 101))\n",
    "    test_dataset = TGS_Dataset(test_path)\n",
    "    test_loader, test_ids = test_dataset.yield_dataloader(data='test', num_workers=11, batch_size=batch_size)\n",
    "    # predict\n",
    "    for i in tqdm(range(len(load_paths))):\n",
    "        net.load_model(load_paths[i])\n",
    "        p = net.do_predict(test_loader, threshold=0, tta_transform=tta_transform)\n",
    "        sum_ += p['pred']\n",
    "    del test_dataset, test_loader\n",
    "    \n",
    "    return sum_, p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tta_transform(images, mode):\n",
    "    out = []\n",
    "    if mode == 'out':\n",
    "        images = images[0]\n",
    "    images = images.transpose((0, 2, 3, 1))\n",
    "    tta = []\n",
    "    for i in range(len(images)):\n",
    "        t = np.fliplr(images[i])\n",
    "        tta.append(t)\n",
    "    tta = np.transpose(tta, (0, 3, 1, 2))\n",
    "    out.append(tta)\n",
    "    return np.asarray(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEST_PATH = './test/'\n",
    "net = UnetSeNext50()\n",
    "NET_NAME = type(net).__name__\n",
    "THRESHOLD = 0.45\n",
    "MIN_SIZE = 0\n",
    "BATCH_SIZE = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2 [00:00<?, ?it/s]/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1890: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\n",
      "  warnings.warn(\"nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\")\n",
      "100%|██████████| 2/2 [09:08<00:00, 270.51s/it]\n"
     ]
    }
   ],
   "source": [
    "filelists = ['./Saves/UnetSeNext50/2018-10-20 03:57_Fold1_Epoach12_reset0_val0.847',\n",
    "            './Saves/UnetSeNext50/2018-10-20 04:24_Fold1_Epoach30_reset1_val0.864']\n",
    "sum_, p = load_net_and_predict(net, TEST_PATH, filelists,\n",
    "                            tta_transform=tta_transform,\n",
    "                            batch_size=BATCH_SIZE,\n",
    "                            threshold=THRESHOLD,min_size=MIN_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg = sum_ / (len(filelists))\n",
    "pred = avg > THRESHOLD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "rle = []\n",
    "for i in range(len(pred)):\n",
    "    rle.append(do_length_encode(pred[i]))\n",
    "# create sub\n",
    "df = pd.DataFrame(dict(id=p['id'], rle_mask=rle))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can modify the filename of the .csv file\n",
    "df.to_csv(os.path.join(\n",
    "        './results/',\n",
    "        'submit.csv'),\n",
    "        index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
